{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Fashion-MNIST training and test arrays via tf.keras.datasets\n",
    "(train_image,train_label),(test_image,test_label) = tf.keras.datasets.fashion_mnist.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 28, 28)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_image.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000,)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_label.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale both image arrays by 1/255 (presumably maps 8-bit pixel values into [0, 1])\n",
    "# NOTE: re-running this cell divides again because it overwrites its own inputs\n",
    "train_image = train_image/255\n",
    "test_image = test_image/255"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the classifier: flatten 28x28 inputs -> 128-unit ReLU layer -> 10-way softmax\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28,28)),\n",
    "    tf.keras.layers.Dense(128,activation='relu'),\n",
    "    tf.keras.layers.Dense(10,activation='softmax'),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten (Flatten)            (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               100480    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 101,770\n",
      "Trainable params: 101,770\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the model.\n",
    "# Use an explicit SparseCategoricalAccuracy metric (still reported as 'acc'): the bare\n",
    "# 'acc' string is resolved from the loss/output shape at compile time and can be\n",
    "# restored as the wrong accuracy variant after load_model() from an HDF5 file,\n",
    "# which makes the reloaded model report near-chance accuracy with an unchanged loss.\n",
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='acc')],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "1875/1875 [==============================] - 4s 2ms/step - loss: 0.5024 - acc: 0.8239\n",
      "Epoch 2/3\n",
      "1875/1875 [==============================] - 3s 2ms/step - loss: 0.3745 - acc: 0.8642\n",
      "Epoch 3/3\n",
      "1875/1875 [==============================] - 3s 2ms/step - loss: 0.3361 - acc: 0.8760\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x22018002b08>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the model for 3 epochs on the training arrays\n",
    "model.fit(train_image,train_label,epochs=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.37809574604034424, 0.8626000285148621]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate on the test set; verbose=1 shows evaluation progress, verbose=0 hides it\n",
    "model.evaluate(test_image,test_label,verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 保存整个模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model; the argument is the destination path\n",
    "model.save('less_model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved model back from disk\n",
    "new_model = tf.keras.models.load_model('less_model.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten (Flatten)            (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               100480    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 101,770\n",
      "Trainable params: 101,770\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "new_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.37809574604034424, 0.10790000110864639]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_model.evaluate(test_image,test_label,verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 仅保存架构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialize the model architecture to a JSON string\n",
    "json_config = model.to_json()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"class_name\": \"Sequential\", \"config\": {\"name\": \"sequential\", \"layers\": [{\"class_name\": \"InputLayer\", \"config\": {\"batch_input_shape\": [null, 28, 28], \"dtype\": \"float32\", \"sparse\": false, \"ragged\": false, \"name\": \"flatten_input\"}}, {\"class_name\": \"Flatten\", \"config\": {\"name\": \"flatten\", \"trainable\": true, \"batch_input_shape\": [null, 28, 28], \"dtype\": \"float32\", \"data_format\": \"channels_last\"}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 128, \"activation\": \"relu\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}, {\"class_name\": \"Dense\", \"config\": {\"name\": \"dense_1\", \"trainable\": true, \"dtype\": \"float32\", \"units\": 10, \"activation\": \"softmax\", \"use_bias\": true, \"kernel_initializer\": {\"class_name\": \"GlorotUniform\", \"config\": {\"seed\": null}}, \"bias_initializer\": {\"class_name\": \"Zeros\", \"config\": {}}, \"kernel_regularizer\": null, \"bias_regularizer\": null, \"activity_regularizer\": null, \"kernel_constraint\": null, \"bias_constraint\": null}}]}, \"keras_version\": \"2.4.0\", \"backend\": \"tensorflow\"}'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "json_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild a model from the JSON config\n",
    "reinitialized_model = tf.keras.models.model_from_json(json_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"sequential\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "flatten (Flatten)            (None, 784)               0         \n",
      "_________________________________________________________________\n",
      "dense (Dense)                (None, 128)               100480    \n",
      "_________________________________________________________________\n",
      "dense_1 (Dense)              (None, 10)                1290      \n",
      "=================================================================\n",
      "Total params: 101,770\n",
      "Trainable params: 101,770\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "reinitialized_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The JSON config holds only the layer architecture (no optimizer/loss),\n",
    "# so the rebuilt model must be compiled before evaluate() can be called\n",
    "reinitialized_model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[2.4520509243011475, 0.1062999963760376]"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# No trained weights have been transferred yet, so accuracy is close to chance (10 classes)\n",
    "reinitialized_model.evaluate(test_image,test_label,verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 仅保存模型权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the trained model's weights\n",
    "weights = model.get_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set those weights on the rebuilt model\n",
    "reinitialized_model.set_weights(weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.37809574604034424, 0.8626000285148621]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reinitialized_model.evaluate(test_image,test_label,verbose=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model weights to disk\n",
    "model.save_weights('less_weights.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved weights from disk\n",
    "reinitialized_model.load_weights('less_weights.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.37809574604034424, 0.8626000285148621]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reinitialized_model.evaluate(test_image,test_label,verbose=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 在训练期间保存检查点"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configure the callback that saves checkpoints (weights only)\n",
    "cp_callback = tf.keras.callbacks.ModelCheckpoint('training_cp/cp.ckpt',save_weights_only=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a new model (rebinds `model`): flatten -> 128-unit ReLU layer -> 10-way softmax\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28,28)),\n",
    "    tf.keras.layers.Dense(128,activation='relu'),\n",
    "    tf.keras.layers.Dense(10,activation='softmax'),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the model\n",
    "model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['acc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.5029 - acc: 0.8213\n",
      "Epoch 2/3\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.3758 - acc: 0.8645\n",
      "Epoch 3/3\n",
      "1875/1875 [==============================] - 5s 3ms/step - loss: 0.3358 - acc: 0.8771\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x22040dd2b08>"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train the model with the checkpoint callback attached.\n",
    "# This must run before load_weights() below: it is what writes 'training_cp/cp.ckpt'.\n",
    "model.fit(train_image,train_label,epochs=3,callbacks=[cp_callback])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x220410b1f08>"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Restore the weights from the saved checkpoint\n",
    "model.load_weights('training_cp/cp.ckpt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自定义训练中保存检查点"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model for the custom training loop; the last Dense layer has no activation,\n",
    "# so it outputs raw scores rather than softmax probabilities\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Flatten(input_shape=(28,28)),\n",
    "    tf.keras.layers.Dense(128,activation='relu'),\n",
    "    tf.keras.layers.Dense(10),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = tf.keras.optimizers.Adam()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss(model,x,y):\n",
    "    \"\"\"Return `loss_func` applied to the labels `y` and the model's output for `x`.\"\"\"\n",
    "    return loss_func(y,model(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Running aggregates for the custom loop: train_step() feeds the train_* metrics and\n",
    "# train() reads and resets them once per epoch.\n",
    "# NOTE(review): test_loss / test_accuracy are not used by any cell visible in this notebook.\n",
    "train_loss = tf.keras.metrics.Mean('train_loss',dtype=tf.float32)\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')\n",
    "test_loss = tf.keras.metrics.Mean('test_loss',dtype=tf.float32)\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('test_accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_step(model,images,labels):\n",
    "    \"\"\"Apply one optimizer update for a batch and record the running train metrics.\"\"\"\n",
    "    # Record the forward pass so the loss can be differentiated w.r.t. the weights\n",
    "    with tf.GradientTape() as tape:\n",
    "        logits = model(images)\n",
    "        batch_loss = loss_func(labels,logits)\n",
    "    variables = model.trainable_variables\n",
    "    gradients = tape.gradient(batch_loss,variables)\n",
    "    optimizer.apply_gradients(zip(gradients,variables))\n",
    "    # Fold this batch into the epoch-level running metrics\n",
    "    train_loss(batch_loss)\n",
    "    train_accuracy(labels,logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices((train_image,train_label))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle with a 10000-element buffer, then group into batches of 32\n",
    "# NOTE: this rebinds `dataset`, so re-running the cell would batch the already-batched dataset\n",
    "dataset = dataset.shuffle(10000).batch(32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory and file prefix under which checkpoints are saved\n",
    "cp_dir = './customtrain_cp'\n",
    "cp_prefix = os.path.join(cp_dir,'ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the checkpoint object; it is given both the optimizer and the model to track\n",
    "checkpoint = tf.train.Checkpoint(optimizer=optimizer,model=model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epochs=5,save_every=2):\n",
    "    \"\"\"Run the custom training loop over `dataset`, checkpointing periodically.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of passes over `dataset` (default 5, the former hard-coded value).\n",
    "        save_every: save a checkpoint after every `save_every`-th epoch (default 2).\n",
    "    \"\"\"\n",
    "    for epoch in range(epochs):\n",
    "        for images,labels in dataset:\n",
    "            train_step(model,images,labels)\n",
    "        print('Epoch{} loss is {}'.format(epoch,train_loss.result()))\n",
    "        print('Epoch{} Accuracy is {}'.format(epoch,train_accuracy.result()))\n",
    "        # Clear the running metrics so the next epoch is reported on its own\n",
    "        train_loss.reset_states()\n",
    "        train_accuracy.reset_states()\n",
    "        if (epoch + 1) % save_every == 0:\n",
    "            checkpoint.save(file_prefix=cp_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch0 loss is 0.3365711569786072\n",
      "Epoch0 Accuracy is 0.8775333166122437\n",
      "Epoch1 loss is 0.31290820240974426\n",
      "Epoch1 Accuracy is 0.8865500092506409\n",
      "Epoch2 loss is 0.29701972007751465\n",
      "Epoch2 Accuracy is 0.8906333446502686\n",
      "Epoch3 loss is 0.2798936069011688\n",
      "Epoch3 Accuracy is 0.8957666754722595\n",
      "Epoch4 loss is 0.27014732360839844\n",
      "Epoch4 Accuracy is 0.8999166488647461\n"
     ]
    }
   ],
   "source": [
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'./customtrain_cp\\\\ckpt-2'"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Look up the most recent checkpoint in the checkpoint directory\n",
    "tf.train.latest_checkpoint(cp_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x22019fa1848>"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Restore the tracked objects from the most recent checkpoint in cp_dir\n",
    "checkpoint.restore(tf.train.latest_checkpoint(cp_dir))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
